{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyNPGJZDkFcJp9nxm0mrt4On",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/nvidia_neo4j_langchain.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PBS17Sk9iXcH",
        "outputId": "629d4a5d-a3f8-469f-dc9c-196726f9576c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.4/50.4 kB\u001b[0m \u001b[31m1.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.3/2.3 MB\u001b[0m \u001b[31m25.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m293.7/293.7 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m379.8/379.8 kB\u001b[0m \u001b[31m14.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m990.6/990.6 kB\u001b[0m \u001b[31m28.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.1/140.1 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.5/4.5 MB\u001b[0m \u001b[31m50.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.2/49.2 kB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m141.1/141.1 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "!pip install --upgrade --quiet langchain-nvidia-ai-endpoints langchain-community neo4j langchain-core"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "from typing import Any, Dict, List, Optional, Tuple, Type\n",
        "\n",
        "from langchain_community.graphs import Neo4jGraph\n",
        "from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars\n",
        "from langchain_core.tools import tool\n",
        "from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
        "\n",
        "from langchain.agents import AgentExecutor\n",
        "from langchain.agents.format_scratchpad import \\\n",
        "    format_to_openai_function_messages\n",
        "from langchain.agents.output_parsers.openai_tools import \\\n",
        "    OpenAIToolsAgentOutputParser\n",
        "from langchain.callbacks.manager import (AsyncCallbackManagerForToolRun,\n",
        "                                         CallbackManagerForToolRun)\n",
        "from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
        "from langchain.pydantic_v1 import BaseModel, Field\n",
        "from langchain.schema import AIMessage, HumanMessage\n",
        "from langchain.tools import BaseTool\n",
        "from langchain.tools.render import format_tool_to_openai_function"
      ],
      "metadata": {
        "id": "JMIjgBS-2lWj"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Build a knowledge graph-based agent with Llama 3.1, NVIDIA NIM, and LangChain\n",
        "## Leverage Llama-3.1 native function-calling capabilities to retrieve structured data from a knowledge graph to power your RAG applications\n",
        "\n",
        "While most people focus on RAG over unstructured text, such as company documents or documentation, I am pretty bullish on retrieval systems over structured information, particularly knowledge graphs. There has been a lot of excitement about GraphRAG, specifically Microsoft's implementation. However, in their implementation, the input data is unstructured text in the form of documents, which is transformed into a knowledge graph using a Large Language Model (LLM).\n",
        "\n",
        "In this blog post, we will show how to implement a retriever over a knowledge graph containing structured information from FDA Adverse Event Reporting System (FAERS), which offers the information about drug adverse events. If you have ever tinkered with knowledge graphs and retrieval, your first thought might be to use an LLM to generate database queries to retrieve relevant information from a knowledge graph to answer a given question. However, database query generation using LLMs is still evolving and may not yet offer the most consistent or robust solution. So, what are the viable alternatives at the moment?\n",
        "\n",
        "In my opinion, the best solution at the moment is the so-called dynamic query generation. Rather than relying entirely on an LLM to generate the complete query, this method employs a logic layer that deterministically generates a database query from predefined input parameters.his solution can be implemented using an LLM with function-calling support. The advantage of using a function-calling feature lies in the ability to define to an LLM how it should prepare a structured input to a function. This approach ensures that the query generation process is controlled and consistent while allowing for user input flexibility.\n",
        "\n",
        "![image](https://cdn-images-1.medium.com/max/1200/1*5N_tHHvD-rY9ZP_nnANDMw.png)\n",
        "\n",
        "The image illustrates a  process of a understanding a user's question to retrieve specific information. The flow involves three main steps:\n",
        "A user asks a question about common side effects of the drug Lyrica for people under 35 years old.\n",
        "\n",
        "The LLM decides which function to call and the parameters needed. In this example, it chose a function named \"side_effects\" with parameters including the drug \"Lyrica\" and a maximum age of 35.\n",
        "\n",
        "The identified function and parameters are used to deterministically and dynamically generate a database query (Cypher) statement to retrieve relevant information.\n",
        "\n",
        "Function-calling support is vital for advanced LLM use cases, such as allowing LLMs to use multiple retrievers based on user intent or building multi-agent flows. I have written some articles using commercial LLMs with native function-calling support. However, in this blog post, we will use Llama-3.1, a superior open-source LLM with native function-calling support that was released just recently.\n",
        "## Setting up the knowledge graph\n",
        "We will use Neo4j, which is a native graph database to store the adverse event information. You can set up a free cloud Sandbox project that comes with pre-populated FAERS by following [this link](https://sandbox.neo4j.com/?usecase=healthcare-analytics).\n",
        "The instantiated database instance has a graph with the following schema.\n",
        "\n",
        "![image](https://cdn-images-1.medium.com/max/800/1*hM90ShEOOWhbQ-6_OcrWYg.png)\n",
        "\n",
        "The schema centers on the Case node, which links various aspects of a drug safety report, including the drugs involved, reactions experienced, outcomes, and therapies prescribed. Each drug is characterized by whether it is primary, secondary, concomitant, or interacting. Cases are also associated with information about the manufacturer, the age group of the patient, and the source of the report. This schema allows for tracking and analyzing the relationships between drugs, their reactions, and outcomes in a structured manner.\n",
        "\n",
        "We'll start by creating a connection to the database by instantiating a Neo4jGraph object."
      ],
      "metadata": {
        "id": "Laaac2KIAcY3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "os.environ[\"NEO4J_URI\"] = \"bolt://18.206.157.187:7687\"\n",
        "os.environ[\"NEO4J_USERNAME\"] = \"neo4j\"\n",
        "os.environ[\"NEO4J_PASSWORD\"] = \"elevation-reservist-thousands\"\n",
        "\n",
        "graph = Neo4jGraph(refresh_schema=False)"
      ],
      "metadata": {
        "id": "67TCB9920Sqq"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Setting up the LLM environment\n",
        "There are many options to host open-source LLMs like the Llama-3.1. In this blog post, we will use the [NVIDIA API catalog](https://www.nvidia.com/en-us/ai/#referrer=ai-subdomain?ncid=ref-inpa-144886), which provides NVIDIA NIM inference microservices and supports function calling for llama 3.1 models. When you create an account you get 1000 tokens, which is more than enough to complete this blog post. You'll need to create an API key and copy it to the notebook."
      ],
      "metadata": {
        "id": "zwuqME9vA6nW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "os.environ[\"NVIDIA_API_KEY\"] = \"nvapi-\"\n",
        "llm = ChatNVIDIA(model=\"meta/llama-3.1-70b-instruct\")"
      ],
      "metadata": {
        "id": "jnSoDL_G03ly"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "We'll be using the **llama-3.1–70b** as the 8b version has some hiccups with optional parameters in function definitions.\n",
        "The nice thing about NVIDIA NIM microservices is that you can easily host them locally if you have security or other concerns, so it's really easily swappable and you only need to add a url parameter to the LLM configuration."
      ],
      "metadata": {
        "id": "sMP-YGJtBF6N"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# connect to an local NIM running at localhost:8000,\n",
        "# specifying a specific model\n",
        "# llm = ChatNVIDIA(\n",
        "#  base_url=\"http://localhost:8000/v1\",\n",
        "#  model=\"meta/llama-3.1-70b-instruct\"\n",
        "# )"
      ],
      "metadata": {
        "id": "QZk-CE75rS_Z"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result = llm.invoke(\"How to use LLMs in combination with Graph Databases? Be concise!\")\n",
        "print(result.content)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9XbET6XqvVUr",
        "outputId": "b6ddf191-3229-4173-f5ff-e74e244c751f"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Here's a concise guide on using Large Language Models (LLMs) with Graph Databases:\n",
            "\n",
            "**Why combine LLMs with Graph Databases?**\n",
            "\n",
            "* Enhance semantic search and querying capabilities\n",
            "* Leverage graph-based relationships for more accurate and informative results\n",
            "* Augment graph data with text-based insights and entity recognition\n",
            "\n",
            "**How to combine LLMs with Graph Databases:**\n",
            "\n",
            "1. **Text preprocessing**: Use LLMs to preprocess text data (e.g., NLP, entity recognition, sentiment analysis) before ingesting it into the graph database.\n",
            "2. **Graph embedding**: Use LLMs to generate graph embeddings, which can be used to represent nodes and edges in the graph database.\n",
            "3. **Query augmentation**: Use LLMs to augment graph queries with natural language understanding, enabling more expressive and flexible querying.\n",
            "4. **Graph-based knowledge graph completion**: Use LLMs to predict missing relationships and entities in the graph database, enhancing knowledge graph completion.\n",
            "5. **Entity disambiguation**: Use LLMs to disambiguate entities in the graph database, improving entity recognition and resolution.\n",
            "\n",
            "**Tools and frameworks:**\n",
            "\n",
            "* Graph databases (e.g., Neo4j, Amazon Neptune)\n",
            "* LLM libraries (e.g., Hugging Face Transformers, spaCy)\n",
            "* Integration frameworks (e.g., GraphQL, Apache Spark)\n",
            "\n",
            "By combining LLMs with Graph Databases, you can unlock powerful insights and applications, such as:\n",
            "* Improved customer 360-degree views\n",
            "* Enhanced knowledge graph-based search and recommendation systems\n",
            "* More accurate and informative graph-based analytics and visualizations\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Tool definition\n",
        "In this example, we will configure a single tool with four optional parameters. Based on those parameters, we will construct a corresponding Cypher statement that will be used to retrieve the relevant information from the knowledge graph.\n",
        "\n",
        "First we will define some utility functions"
      ],
      "metadata": {
        "id": "6fFICQWMBQHo"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "graph.query(\n",
        "    \"CREATE FULLTEXT INDEX drug IF NOT EXISTS FOR (d:Drug) ON EACH [d.name];\"\n",
        ")\n",
        "graph.query(\n",
        "    \"CREATE FULLTEXT INDEX manufacturer IF NOT EXISTS FOR (d:Manufacturer) ON EACH [d.manufacturerName];\"\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RQukw9a34nIY",
        "outputId": "68ccf9e1-b54b-4489-e768-1b23f5fdf80e"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[]"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def generate_full_text_query(input: str) -> str:\n",
        "    \"\"\"\n",
        "    Generate a full-text search query for a given input string.\n",
        "\n",
        "    This function constructs a query string suitable for a full-text search.\n",
        "    It processes the input string by splitting it into words and appending a\n",
        "    similarity threshold (~2) to each word, then combines them using the AND\n",
        "    operator. Useful for mapping movies and people from user questions\n",
        "    to database values, and allows for some misspelings.\n",
        "    \"\"\"\n",
        "    full_text_query = \"\"\n",
        "    words = [el for el in remove_lucene_chars(input).split() if el]\n",
        "    for word in words[:-1]:\n",
        "        full_text_query += f\" {word}~2 AND\"\n",
        "    full_text_query += f\" {words[-1]}~2\"\n",
        "    return full_text_query.strip()\n",
        "\n",
        "\n",
        "candidate_query = \"\"\"\n",
        "CALL db.index.fulltext.queryNodes($index, $fulltextQuery, {limit: $limit})\n",
        "YIELD node\n",
        "RETURN coalesce(node.manufacturerName, node.name) AS candidate,\n",
        "       labels(node)[0] AS label\n",
        "\"\"\"\n",
        "\n",
        "\n",
        "def get_candidates(input: str, type: str, limit: int = 3) -> List[Dict[str, str]]:\n",
        "    \"\"\"\n",
        "    Retrieve a list of candidate entities from database based on the input string.\n",
        "\n",
        "    This function queries the Neo4j database using a full-text search. It takes the\n",
        "    input string, generates a full-text query, and executes this query against the\n",
        "    specified index in the database. The function returns a list of candidates\n",
        "    matching the query, with each candidate being a dictionary containing their name\n",
        "    (or title) and label (either 'Person' or 'Movie').\n",
        "    \"\"\"\n",
        "    ft_query = generate_full_text_query(input)\n",
        "    candidates = graph.query(\n",
        "        candidate_query, {\"fulltextQuery\": ft_query, \"index\": type, \"limit\": limit}\n",
        "    )\n",
        "    return candidates"
      ],
      "metadata": {
        "id": "046n-sfaKge1"
      },
      "execution_count": 8,
      "outputs": []
    },
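    {
      "cell_type": "markdown",
      "source": [
        "To make the fuzzy matching more concrete, here is a small standalone sketch of the kind of Lucene query string that `generate_full_text_query` builds. The `fuzzy_query` function is a simplified, hypothetical stand-in: it assumes that replacing every character that is not an ASCII letter or digit with a space is a good enough approximation of `remove_lucene_chars`, which is not exactly what the LangChain helper does.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "\n",
        "def fuzzy_query(text: str, distance: int = 2) -> str:\n",
        "    # Replace anything that is not an ASCII letter or digit with a space\n",
        "    # (a simplifying assumption), then split the text into words.\n",
        "    words = re.sub('[^0-9A-Za-z]+', ' ', text).split()\n",
        "    # Allow up to `distance` edits per word and require every word to match.\n",
        "    return ' AND '.join(f'{word}~{distance}' for word in words)\n",
        "\n",
        "\n",
        "print(fuzzy_query('lyrica cr'))  # lyrica~2 AND cr~2\n",
        "```\n",
        "\n",
        "The `~2` suffix asks Lucene for a fuzzy match within an edit distance of two, which is why a slightly misspelled drug name can still be mapped to a value in the database."
      ],
      "metadata": {}
    },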
    {
      "cell_type": "code",
      "source": [
        "get_candidates(\"voriconazol\", \"drug\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4ENXeeziLVhx",
        "outputId": "fd4c3e76-908d-4885-ee6e-bb390578723e"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'candidate': 'VORICONAZOLE', 'label': 'Drug'}]"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Specifically, our tool will be able to identify the most frequent side effects based on input drug, age, and the drug manufacturer."
      ],
      "metadata": {
        "id": "vl00HXTJBXYL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "@tool\n",
        "def get_side_effects(\n",
        "    drug: Optional[str] = Field(\n",
        "        description=\"disease mentioned in the question. Return None if no mentioned.\"\n",
        "    ),\n",
        "    min_age: Optional[int] = Field(\n",
        "        description=\"Minimum age of the patient. Return None if no mentioned.\"\n",
        "    ),\n",
        "    max_age: Optional[int] = Field(\n",
        "        description=\"Maximum age of the patient. Return None if no mentioned.\"\n",
        "    ),\n",
        "    manufacturer: Optional[str] = Field(\n",
        "        description=\"manufacturer of the drug. Return None if no mentioned.\"\n",
        "    ),\n",
        "):\n",
        "    \"\"\"Useful for when you need to find common side effects.\"\"\"\n",
        "    params = {}\n",
        "    filters = []\n",
        "    side_effects_base_query = \"\"\"\n",
        "    MATCH (c:Case)-[:HAS_REACTION]->(r:Reaction), (c)-[:IS_PRIMARY_SUSPECT]->(d:Drug)\n",
        "    \"\"\"\n",
        "    if drug and isinstance(drug, str):\n",
        "        candidate_drugs = [el[\"candidate\"] for el in get_candidates(drug, \"drug\")]\n",
        "        if not candidate_drugs:\n",
        "            return \"The mentioned drug was not found\"\n",
        "        filters.append(\"d.name IN $drugs\")\n",
        "        params[\"drugs\"] = candidate_drugs\n",
        "\n",
        "    if min_age and isinstance(min_age, int):\n",
        "        filters.append(\"c.age > $min_age \")\n",
        "        params[\"min_age\"] = min_age\n",
        "    if max_age and isinstance(max_age, int):\n",
        "        filters.append(\"c.age < $max_age \")\n",
        "        params[\"max_age\"] = max_age\n",
        "    if manufacturer and isinstance(manufacturer, str):\n",
        "        candidate_manufacturers = [\n",
        "            el[\"candidate\"] for el in get_candidates(manufacturer, \"manufacturer\")\n",
        "        ]\n",
        "        if not candidate_manufacturers:\n",
        "            return \"The mentioned manufacturer was not found\"\n",
        "        filters.append(\n",
        "            \"EXISTS {(c)<-[:REGISTERED]-(:Manufacturer {manufacturerName: $manufacturer})}\"\n",
        "        )\n",
        "        params[\"manufacturer\"] = candidate_manufacturers[0]\n",
        "\n",
        "    if filters:\n",
        "        side_effects_base_query += \" WHERE \"\n",
        "        side_effects_base_query += \" AND \".join(filters)\n",
        "    side_effects_base_query += \"\"\"\n",
        "    RETURN d.name AS drug, r.description AS side_effect, count(*) AS count\n",
        "    ORDER BY count DESC\n",
        "    LIMIT 10\n",
        "    \"\"\"\n",
        "    print(f\"Using parameters: {params}\")\n",
        "    data = graph.query(side_effects_base_query, params=params)\n",
        "    return data"
      ],
      "metadata": {
        "id": "jLyPt3vGUdpz"
      },
      "execution_count": 10,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The get_side_effectsfunction is designed to retrieve common side effects of drugs from a knowledge graph using specified search criteria. It accepts optional parameters for drug name, patient age range, and drug manufacturer to customize the search. Each parameter has a description that is passed to an LLM along with the function description, enabling the LLM to understand how to use them. The function then constructs a dynamic Cypher query based on the provided inputs, executes this query against the knowledge graph, and returns the resulting side effects data.\n",
        "\n",
        "Let's test the function."
      ],
      "metadata": {
        "id": "nrLjRDHrBaGe"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "get_side_effects(\"lyrica\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XXovEQAiN3Mo",
        "outputId": "fe57375f-6f06-47d2-a942-4ba1886acc43"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Using parameters: {'drugs': ['LYRICA', 'LYRICA CR']}\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The method `BaseTool.__call__` was deprecated in langchain-core 0.1.47 and will be removed in 0.3.0. Use invoke instead.\n",
            "  warn_deprecated(\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[{'drug': 'LYRICA', 'side_effect': 'Pain', 'count': 32},\n",
              " {'drug': 'LYRICA', 'side_effect': 'Fall', 'count': 21},\n",
              " {'drug': 'LYRICA',\n",
              "  'side_effect': 'Intentional product use issue',\n",
              "  'count': 20},\n",
              " {'drug': 'LYRICA', 'side_effect': 'Insomnia', 'count': 19},\n",
              " {'drug': 'LYRICA', 'side_effect': 'Feeling abnormal', 'count': 18},\n",
              " {'drug': 'LYRICA', 'side_effect': 'Drug ineffective', 'count': 18},\n",
              " {'drug': 'LYRICA', 'side_effect': 'Memory impairment', 'count': 17},\n",
              " {'drug': 'LYRICA', 'side_effect': 'Withdrawal syndrome', 'count': 17},\n",
              " {'drug': 'LYRICA', 'side_effect': 'Malaise', 'count': 16},\n",
              " {'drug': 'LYRICA', 'side_effect': 'Intentional product misuse', 'count': 15}]"
            ]
          },
          "metadata": {},
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Our tool first mapped the \"lyrica\" drug mentioned in the question to \"['LYRICA', 'LYRICA CR']\" values in the knowledge graph and then executed corresponding Cypher statement to find the most frequent side effects.\n",
        "## Graph-based LLM Agent\n",
        "The only thing left to do is configure an LLM agent that can use the defined tool to answer questions about the drug's side effects.\n",
        "\n",
        "![image](https://cdn-images-1.medium.com/max/800/1*5Q4y5emOAhR7kw2L7rpeaw.png)\n",
        "\n",
        "The image depicts a user interacting with a Llama-3.1 agent to inquire about drug side effects. The agent accesses a side effects tool that retrieves information from a knowledge graph to provide the user with the relevant data.\n",
        "\n",
        "We'll start by defining the prompt template."
      ],
      "metadata": {
        "id": "mLUQn3JaBdoT"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "prompt = ChatPromptTemplate.from_messages(\n",
        "    [\n",
        "        (\n",
        "            \"system\",\n",
        "            \"You are a helpful assistant that finds information about common side effects. \"\n",
        "            \"If tools require follow up questions, \"\n",
        "            \"make sure to ask the user for clarification. Make sure to include any \"\n",
        "            \"available options that need to be clarified in the follow up questions \"\n",
        "            \"Do only the things the user specifically requested. \",\n",
        "        ),\n",
        "        MessagesPlaceholder(variable_name=\"chat_history\"),\n",
        "        (\"user\", \"{input}\"),\n",
        "        MessagesPlaceholder(variable_name=\"agent_scratchpad\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "def _format_chat_history(chat_history: List[Tuple[str, str]]):\n",
        "    buffer = []\n",
        "    for human, ai in chat_history:\n",
        "        buffer.append(HumanMessage(content=human))\n",
        "        buffer.append(AIMessage(content=ai))\n",
        "    return buffer"
      ],
      "metadata": {
        "id": "_BnUZve2Mokm"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The prompt template includes the system message, optional chat history, and user input. The agent_scratchpad is reserved for the LLM, as it sometimes needs to perform multiple steps to answer the question, like executing and retrieving information from tools.\n",
        "\n",
        "The LangChain library makes it straightforward to add tools to the LLM by using the bind_tools method."
      ],
      "metadata": {
        "id": "G2OqWRznBmQS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "tools = [get_side_effects]\n",
        "llm_with_tools = llm.bind_tools(tools=tools)\n",
        "\n",
        "agent = (\n",
        "    {\n",
        "        \"input\": lambda x: x[\"input\"],\n",
        "        \"chat_history\": lambda x: _format_chat_history(x[\"chat_history\"])\n",
        "        if x.get(\"chat_history\")\n",
        "        else [],\n",
        "        \"agent_scratchpad\": lambda x: format_to_openai_function_messages(\n",
        "            x[\"intermediate_steps\"]\n",
        "        ),\n",
        "    }\n",
        "    | prompt\n",
        "    | llm_with_tools\n",
        "    | OpenAIToolsAgentOutputParser()\n",
        ")\n",
        "\n",
        "\n",
        "# Add typing for input\n",
        "class AgentInput(BaseModel):\n",
        "    input: str\n",
        "    chat_history: List[Tuple[str, str]] = Field(\n",
        "        ..., extra={\"widget\": {\"type\": \"chat\", \"input\": \"input\", \"output\": \"output\"}}\n",
        "    )\n",
        "\n",
        "\n",
        "class Output(BaseModel):\n",
        "    output: Any\n",
        "\n",
        "\n",
        "agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True).with_types(\n",
        "    input_type=AgentInput, output_type=Output\n",
        ")"
      ],
      "metadata": {
        "id": "D8icN7K9O6IT"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The agent processes input through a series of transformations and handlers that format the chat history, apply the LLM with the bound tools, and parse the output. Finally, the agent is set up with an executor that manages the execution flow, specifies input and output types, and includes verbosity settings for detailed logging during execution.\n",
        "\n",
        "Let's now test the agent."
      ],
      "metadata": {
        "id": "23IsRLZCBphb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "agent_executor.invoke(\n",
        "    {\n",
        "        \"input\": \"What are the most common side effects when using lyrica for people below 35 years old?\"\n",
        "    }\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Gi4IvWKOPocK",
        "outputId": "1c30955f-cb0f-4680-84f1-9e226d8e5475"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "\n",
            "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
            "\u001b[32;1m\u001b[1;3m\n",
            "Invoking: `get_side_effects` with `{'drug': 'lyrica', 'max_age': 35}`\n",
            "\n",
            "\n",
            "\u001b[0mUsing parameters: {'drugs': ['LYRICA', 'LYRICA CR'], 'max_age': 35}\n",
            "\u001b[36;1m\u001b[1;3m[{'drug': 'LYRICA', 'side_effect': 'Sinusitis', 'count': 1}, {'drug': 'LYRICA', 'side_effect': 'Hypoacusis', 'count': 1}, {'drug': 'LYRICA', 'side_effect': 'Fall', 'count': 1}, {'drug': 'LYRICA', 'side_effect': 'Brain injury', 'count': 1}, {'drug': 'LYRICA', 'side_effect': 'Amnesia', 'count': 1}, {'drug': 'LYRICA', 'side_effect': 'Off label use', 'count': 1}, {'drug': 'LYRICA', 'side_effect': 'Impaired quality of life', 'count': 1}, {'drug': 'LYRICA', 'side_effect': 'Somnolence', 'count': 1}, {'drug': 'LYRICA', 'side_effect': 'Drug ineffective for unapproved indication', 'count': 1}]\u001b[0m\u001b[32;1m\u001b[1;3mBased on the provided data, the most common side effects of Lyrica for people below 35 years old are sinusitis, hypoacusis, fall, brain injury, amnesia, off-label use, impaired quality of life, somnolence, and drug ineffective for unapproved indication.\u001b[0m\n",
            "\n",
            "\u001b[1m> Finished chain.\u001b[0m\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input': 'What are the most common side effects when using lyrica for people below 35 years old?',\n",
              " 'output': 'Based on the provided data, the most common side effects of Lyrica for people below 35 years old are sinusitis, hypoacusis, fall, brain injury, amnesia, off-label use, impaired quality of life, somnolence, and drug ineffective for unapproved indication.'}"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "agent_executor.invoke(\n",
        "    {\n",
        "        \"input\": \"What are the most common side effects when for drugs manufactured by acadia?\"\n",
        "    }\n",
        ")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8KXmv4hcvKy_",
        "outputId": "a5a759c6-c123-4b87-bbf0-5f3d307f8b0b"
      },
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "\n",
            "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
            "\u001b[32;1m\u001b[1;3m\n",
            "Invoking: `get_side_effects` with `{'manufacturer': 'acadia'}`\n",
            "\n",
            "\n",
            "\u001b[0mUsing parameters: {'manufacturer': 'ACADIA PHARMACEUTICALS'}\n",
            "\u001b[36;1m\u001b[1;3m[{'drug': 'NUPLAZID', 'side_effect': 'Hallucination', 'count': 13}, {'drug': 'NUPLAZID', 'side_effect': 'Confusional state', 'count': 7}, {'drug': 'NUPLAZID', 'side_effect': 'Fall', 'count': 6}, {'drug': 'NUPLAZID', 'side_effect': 'Delusion', 'count': 5}, {'drug': 'NUPLAZID', 'side_effect': 'Gait disturbance', 'count': 5}, {'drug': 'NUPLAZID', 'side_effect': 'Fatigue', 'count': 4}, {'drug': 'NUPLAZID', 'side_effect': 'Abnormal behaviour', 'count': 3}, {'drug': 'NUPLAZID', 'side_effect': 'Product dose omission issue', 'count': 3}, {'drug': 'NUPLAZID', 'side_effect': 'Agitation', 'count': 3}, {'drug': 'NUPLAZID', 'side_effect': 'Death', 'count': 3}]\u001b[0m\u001b[32;1m\u001b[1;3mThe most common side effects for drugs manufactured by Acadia are: Hallucination, Confusional state, Fall, Delusion, Gait disturbance, Fatigue, Abnormal behaviour, Product dose omission issue, Agitation, and Death.\u001b[0m\n",
            "\n",
            "\u001b[1m> Finished chain.\u001b[0m\n"
          ]
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'input': 'What are the most common side effects when for drugs manufactured by acadia?',\n",
              " 'output': 'The most common side effects for drugs manufactured by Acadia are: Hallucination, Confusional state, Fall, Delusion, Gait disturbance, Fatigue, Abnormal behaviour, Product dose omission issue, Agitation, and Death.'}"
            ]
          },
          "metadata": {},
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "The LLM identified it needs to use the get_side_effects function with appropriate arguments. The function then dynamically generates a Cypher statement, fetches the relevant information, and returns it to the LLM to generate the final answer.\n",
        "## Summary\n",
        "Function calling capabilities are a powerful addition to open-source models like Llama 3.1, enabling more structured and controlled interactions with external data sources and tools. Beyond just querying unstructured documents, graph-based agents offer exciting possibilities for interacting with knowledge graphs and structured data. The ease of hosting these models using platforms like [NVIDIA NIM microservices](https://www.nvidia.com/en-us/ai/#referrer=ai-subdomain?ncid=ref-inpa-144886) makes them increasingly accessible."
      ],
      "metadata": {
        "id": "Xaz56SVUBvPz"
      }
    }
  ]
}